
Add support for gradient checkpointing in BERT #4659

Merged
9 commits merged on Jun 22, 2020

Conversation

@ibeltagy (Contributor) commented May 29, 2020

This PR adds support for gradient checkpointing in modeling_bert.py to save memory at training time at the expense of a slower backward pass. This is particularly useful if we want to pretrain a version of BERT for sequences longer than 512. It is also useful for long-document models like Longformer.

Stats:

Forward/backward - no grad checkpointing: 40.1GB memory, 25.3 seconds. 
Forward/backward - with grad checkpointing: 8.2GB memory (~5x less), 33.5 seconds (~1.3x more)
Forward pass only - with/without gradient checkpointing: 4GB memory, 6.1 seconds.
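
For reference, enabling it is just a config flag; here is a minimal usage sketch (the flag name and the from_pretrained call mirror the reproduction example further down this thread, so treat this as a sketch rather than documented API):

import torch
from transformers import BertForPreTraining

# Load BERT with gradient checkpointing enabled via the new config flag.
model = BertForPreTraining.from_pretrained(
    "bert-base-uncased", gradient_checkpointing=True
)

# The training step itself is unchanged; only memory/speed change, because each
# layer's activations are recomputed during the backward pass instead of stored.
input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])
outputs = model(input_ids)
outputs[0].sum().backward()  # placeholder loss, just to drive a backward pass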

@codecov-commenter commented May 29, 2020

Codecov Report

Merging #4659 into master will decrease coverage by 0.34%.
The diff coverage is 50.00%.


@@            Coverage Diff             @@
##           master    #4659      +/-   ##
==========================================
- Coverage   78.40%   78.06%   -0.35%     
==========================================
  Files         138      138              
  Lines       23757    23766       +9     
==========================================
- Hits        18627    18552      -75     
- Misses       5130     5214      +84     
| Impacted Files | Coverage Δ |
|---|---|
| src/transformers/modeling_bert.py | 87.50% <44.44%> (-0.72%) ⬇️ |
| src/transformers/configuration_bert.py | 100.00% <100.00%> (ø) |
| src/transformers/data/processors/squad.py | 34.07% <0.00%> (-22.62%) ⬇️ |
| src/transformers/modeling_openai.py | 79.51% <0.00%> (-1.39%) ⬇️ |
| src/transformers/tokenization_bert.py | 90.45% <0.00%> (-0.83%) ⬇️ |
| src/transformers/tokenization_utils.py | 94.81% <0.00%> (-0.38%) ⬇️ |
| src/transformers/modeling_tf_utils.py | 85.81% <0.00%> (-0.30%) ⬇️ |
| src/transformers/modeling_utils.py | 91.28% <0.00%> (+0.12%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@LysandreJik (Member) left a comment

I think this is a great addition, from which all models could benefit. Let's see what the rest of the team thinks and we'll look at upstreaming it in the transformers.PretrainedModel if everyone's on board.

@ibeltagy (Contributor, Author) commented May 29, 2020

> we'll look at upstreaming it in the transformers.PretrainedModel if everyone's on board.

Thanks, @LysandreJik. It would be great to make gradient_checkpointing available to more models. While the configuration can be upstreamed into transformers.PretrainedConfig, the implementation is model specific: you need to call torch.utils.checkpoint.checkpoint inside the layers loop, as done in this PR.
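
Concretely, the per-model pattern looks roughly like this (a simplified sketch with a hypothetical toy layer, not the actual BertLayer signatures or the exact diff):

import torch
import torch.nn as nn
import torch.utils.checkpoint

class ToyLayer(nn.Module):
    # Hypothetical stand-in for BertLayer so the sketch runs end to end.
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None):
        return torch.relu(self.linear(hidden_states))

def run_layers(layers, hidden_states, attention_mask=None, gradient_checkpointing=False):
    # Sketch of the encoder's layer loop: the only change is swapping the direct
    # layer call for torch.utils.checkpoint.checkpoint.
    for layer_module in layers:
        if gradient_checkpointing:
            # This layer's activations are recomputed in backward instead of stored.
            hidden_states = torch.utils.checkpoint.checkpoint(
                layer_module, hidden_states, attention_mask
            )
        else:
            hidden_states = layer_module(hidden_states, attention_mask)
    return hidden_states

layers = nn.ModuleList(ToyLayer(16) for _ in range(4))
x = torch.randn(2, 8, 16, requires_grad=True)
mask = torch.ones(2, 8)
out = run_layers(layers, x, mask, gradient_checkpointing=True)
out.sum().backward()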

@LysandreJik (Member)

I was thinking of having the implementation be model agnostic as well. I haven't really thought out the best way, but a possible way to achieve it would be with a decorator; for example, in PretrainedModel we could have something like:

    @staticmethod
    def gradient_checkpointing(layer):
        @functools.wraps(layer)
        def wrapper(*args):
            layer_instance = args[0]
            # Remove the wrapper to prevent infinite recursion on the wrapper
            layer_instance.forward = functools.partial(layer_instance.forward.__wrapped__, layer_instance)
            
            if args[0].config.gradient_checkpointing:
                return torch.utils.checkpoint.checkpoint(layer_instance, *args[1:])
            else:
                return layer(*args)
        return wrapper

Then we can very simply add that decorator on the layers where we want the checkpoint:

class BertLayer(nn.Module):

    ...

    @PreTrainedModel.gradient_checkpointing
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):

    ...

This would require that these layers have access to the configuration so that they're aware of whether gradient checkpointing is enabled.

Pretty convenient, but pretty different from our coding style as well cc @thomwolf

@ibeltagy (Contributor, Author)

neat

@minimaxir commented Jun 1, 2020

A model-agnostic approach might be best. In my research isolating minimaxir/aitextgen#6 for finetuning larger GPT-2 models, it appeared that checkpointing would have to be implemented at the model level, as this PR does for BERT.

@ewrfcas commented Jun 3, 2020

torch.utils.checkpoint.checkpoint works well on a single GPU, but it causes OOM on multi-GPU with torch.nn.DataParallel.

@patrickvonplaten (Contributor)

> I was thinking of having the implementation be model agnostic as well. I haven't really thought out the best way, but a possible way to achieve it would be with a decorator [...] Then we can very simply add that decorator on the layers where we want the checkpoint [...] Pretty convenient, but pretty different from our coding style as well cc @thomwolf

I like the idea of having a decorator function! Would it be enough to have this wrapper only on all "Model" forward functions, like BertModel.forward()?

@ibeltagy (Contributor, Author) commented Jun 3, 2020

> torch.utils.checkpoint.checkpoint works well on a single GPU, but it causes OOM on multi-GPU with torch.nn.DataParallel.

I haven't tried torch.nn.DataParallel, but it works well with torch.nn.DistributedDataParallel on a single machine or multiple machines.

@ibeltagy (Contributor, Author) commented Jun 3, 2020

> I like the idea of having a decorator function! Would it be enough to have this wrapper only on all "Model" forward functions, like BertModel.forward()?

I don't think so. Even with the decorator, it is still model-specific; the decorator just makes the syntax easier. You still need to decide where to call it: with too few calls (e.g. only on BertModel.forward) you won't save enough memory, and with too many calls (e.g. on every .forward function) the backward pass will be very slow.
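
For intuition on this trade-off, PyTorch's torch.utils.checkpoint.checkpoint_sequential makes the granularity explicit: it splits a stack of layers into a chosen number of checkpointed segments. A toy illustration (not related to the BERT code in this PR):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy stack of layers standing in for a transformer's layer list.
stack = nn.Sequential(*[nn.Linear(32, 32) for _ in range(8)])
x = torch.randn(4, 32, requires_grad=True)

# With segments=2, only the inputs to each of the 2 segments are kept;
# activations inside a segment are recomputed during the backward pass.
out = checkpoint_sequential(stack, 2, x)
out.sum().backward()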

@LysandreJik (Member)

Pinging @julien-c so he can take a look.

@ewrfcas commented Jun 4, 2020

> torch.utils.checkpoint.checkpoint works well on a single GPU, but it causes OOM on multi-GPU with torch.nn.DataParallel.
>
> I haven't tried torch.nn.DataParallel, but it works well with torch.nn.DistributedDataParallel on a single machine or multiple machines.

Thanks for the advice. But I tried torch.nn.DistributedDataParallel and hit the same problem as pytorch/pytorch#24005. The PyTorch version is 1.2.0.

The code is:

if n_gpu > 1:
    # model = torch.nn.DataParallel(model)
    torch.distributed.init_process_group(backend="nccl")
    model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)

Both find_unused_parameters=True and find_unused_parameters=False produce errors (the error tracebacks were attached as screenshots).

@LysandreJik (Member)

@ibeltagy, after some back and forth offline with @julien-c and @thomwolf, the way you implemented it is preferred as it's simpler to understand and adheres better to the library's philosophy.

I think we can merge this and then in a following PR add it to all the other models. Would you like to take care of that? No worries if not, I can definitely take care of it.

@ibeltagy (Contributor, Author)

@LysandreJik, glad this will be merged.

> Would you like to take care of that? No worries if not, I can definitely take care of it.

I will pass :D

@schinger

> torch.utils.checkpoint.checkpoint works well on a single GPU, but it causes OOM on multi-GPU with torch.nn.DataParallel.
>
> I haven't tried torch.nn.DataParallel, but it works well with torch.nn.DistributedDataParallel on a single machine or multiple machines.
>
> Thanks for the advice. But I tried torch.nn.DistributedDataParallel and hit the same problem as pytorch/pytorch#24005. The PyTorch version is 1.2.0. [...] Both find_unused_parameters=True and find_unused_parameters=False produce errors.

I encounter the same issue with torch 1.5.0 and the latest transformers.

@ibeltagy (Contributor, Author)

@ewrfcas, @schinger, do you have a small example that reproduces the error?

I don't think we can fix this issue (it needs a PyTorch fix, pytorch/pytorch#24005), but I think we can work around it by removing the unused parameters mentioned in the error message.

@schinger commented Jun 12, 2020

> @ewrfcas, @schinger, do you have a small example that reproduces the error?
>
> I don't think we can fix this issue (it needs a PyTorch fix, pytorch/pytorch#24005), but I think we can work around it by removing the unused parameters mentioned in the error message.

SQuAD example training can reproduce this error: https://github.com/huggingface/transformers/tree/master/examples/question-answering

python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py \
  --model_type bert \
  --model_name_or_path bert-large-uncased-whole-word-masking \
  --do_train \
  --do_eval \
  --do_lower_case \
  --train_file SQUAD_DIR/dev-v1.1.json \
  --predict_file SQUAD_DIR/dev-v1.1.json \
  --learning_rate 3e-5 \
  --num_train_epochs 2 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir ./examples/models/wwm_uncased_finetuned_squad/ \
  --per_gpu_eval_batch_size=1 \
  --per_gpu_train_batch_size=1

no matter whether find_unused_parameters is true or false.

@ibeltagy (Contributor, Author)

Thanks. It would be more helpful if you could provide a simpler and smaller example that I can easily run.

@schinger

> Thanks. It would be more helpful if you could provide a simpler and smaller example that I can easily run.

You can change --train_file to SQUAD_DIR/dev-v1.1.json; the dev set is small enough to run quickly.

@schinger

> torch.utils.checkpoint.checkpoint works well on a single GPU, but it causes OOM on multi-GPU with torch.nn.DataParallel.
>
> I haven't tried torch.nn.DataParallel, but it works well with torch.nn.DistributedDataParallel on a single machine or multiple machines.

Could you show me an example where gradient checkpointing works successfully with torch.nn.DistributedDataParallel on multiple GPUs?

@ewrfcas commented Jun 12, 2020

> @ewrfcas, @schinger, do you have a small example that reproduces the error?
>
> I don't think we can fix this issue (it needs a PyTorch fix, pytorch/pytorch#24005), but I think we can work around it by removing the unused parameters mentioned in the error message.

I have trained a base model instead of a large one to postpone this problem. The only differences in the code are:

class BertEncoder(nn.Module):
    def forward(...):
        ...
        for i, layer_module in enumerate(self.layer):
            ...
            if self.use_grad_ckpt:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    layer_module, hidden_states, attention_mask, head_mask[i],
                    encoder_hidden_states, encoder_attention_mask)
            else:
                layer_outputs = layer_module(
                    hidden_states, attention_mask, head_mask[i],
                    encoder_hidden_states, encoder_attention_mask)
            ...
        ...

and

torch.distributed.init_process_group(backend="nccl")
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)

The rest of the code is the same as the normal finetuning code.

@ibeltagy (Contributor, Author) commented Jun 12, 2020

Here's a small example to replicate the error

import os
import torch
from transformers import BertForPreTraining
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

model = BertForPreTraining.from_pretrained('bert-base-uncased', gradient_checkpointing=True).cuda()
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
outputs = model(torch.tensor([[1, 2, 3]]).cuda())
outputs[0].sum().backward()

Use find_unused_parameters=True and you will get

RuntimeError: Expected to mark a variable ready only once. This error is caused by use of a module parameter outside the `forward` function.

Use find_unused_parameters=False, and things will work just fine.

I couldn't replicate the other error,

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. 

@ewrfcas, do you know how to modify the example above to reproduce it?

@schinger, can you try find_unused_parameters=False to see if it fixes your problem?

@yjernite mentioned this pull request on Jun 12, 2020
@ewrfcas commented Jun 14, 2020

> Here's a small example to replicate the error [...] Use find_unused_parameters=True and you will get "RuntimeError: Expected to mark a variable ready only once." Use find_unused_parameters=False, and things will work just fine. [...]
>
> @ewrfcas, do you know how to modify the example above to reproduce it?

I have tried this code. Although it works the first time, the second forward will fail. You can try repeating the forward and loss.backward() a few times.

@ibeltagy (Contributor, Author) commented Jun 14, 2020

@ewrfcas, I get this error with both gradient_checkpointing=True and gradient_checkpointing=False (btw, gradient_checkpointing was set to False in the example above and I just updated it), so it is a problem with the example, not with gradient checkpointing. Can you try to fix the example? Or can you try it in a training loop that uses DDP correctly, either with pytorch-lightning or the huggingface trainer?

@ewrfcas commented Jun 15, 2020

> @ewrfcas, I get this error with both gradient_checkpointing=True and gradient_checkpointing=False [...] Can you try to fix the example? Or can you try it in a training loop that uses DDP correctly, either with pytorch-lightning or the huggingface trainer?

I have solved this problem by removing the self.pooler layer in BertModel, because it does not forward anything during training. As the error says, all parameters should contribute to the loss for DistributedDataParallel with find_unused_parameters=False, and find_unused_parameters=True is incompatible with gradient_checkpointing.

Maybe we need this check after the first backward:

# check parameters with no grad
for n, p in model.named_parameters():
    if p.grad is None and p.requires_grad is True:
        print('No forward parameters:', n, p.shape)
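
If that check does flag parameters, one possible follow-up is sketched below. The comment above simply removed self.pooler; freezing the unused parameters instead is an alternative (my assumption, not something done in this PR), and it only makes sense when the pooler output truly does not feed your loss (e.g. a QA or token-level head; BertForPreTraining does use the pooler):

# Freeze parameters that never receive gradients so that DDP with
# find_unused_parameters=False does not wait for their reduction.
for name, param in model.named_parameters():
    if "pooler" in name:
        param.requires_grad = False

# Wrap *after* freezing: DDP only registers parameters that require gradients.
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=False)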

@ibeltagy (Contributor, Author) commented Jun 15, 2020

Nice find, @ewrfcas.

@LysandreJik, what is the best way to address this problem? Do we need to fix it, or can we leave it to the user to make sure all the model params are used? Maybe in a separate PR we can find a way to automatically remove unused model params?

Also, aside from this issue, is there anything else needed to merge the PR?

@LysandreJik (Member)

Right, I think this should be looked at in a separate PR. Will take a final look and merge this PR tomorrow, and then look at implementing gradient checkpointing to the rest of the models.
